# -*- coding: UTF-8 -*-
"""
    @Author:YTQ
    @Time: 2022/11/28 19:41
    Description:
    
"""
import torch

from modle.TextCNN import TextCNN
from utils.utils import get_label, evaluate
from DataLoad.NewsDataSet import DataSet
from torch.utils import data
import torch.nn as nn
import logging


def test(modelPath, config):
    # 类别
    labels, label_dir = get_label()
    # 数据集
    test_dataset = DataSet(types='test')
    test_loader = data.DataLoader(test_dataset, batch_size=10, shuffle=False)
    # 模型
    model = TextCNN(config=config)
    model.load_state_dict(torch.load(modelPath, map_location=config.device))
    loss_fn = nn.CrossEntropyLoss()

    y_pred = []
    y_true = []
    with torch.no_grad():
        for b, (x, mask, target) in enumerate(test_loader):
            test_pred = model(x, mask)
            loss = loss_fn(test_pred, target)

            logging.info(f'batch:{b}, loss:{round(loss.item(), 5)}')

            test_pred_ = torch.argmax(test_pred, dim=1)

            y_pred += test_pred_.data.tolist()
            y_true += target.data.tolist()

        acc = evaluate(y_pred, y_true, labels)
        logging.info('模型准确率为:')
        logging.info(acc)


"""
round_1: (训练1轮)
               precision    recall  f1-score   support
        金融       0.95      0.88      0.91      1000
        房产       0.97      0.93      0.95      1000
        股票       0.86      0.92      0.89      1000
        教育       0.97      0.95      0.96      1000
        科技       0.89      0.91      0.90      1000
        社会       0.92      0.94      0.93      1000
        政治       0.93      0.91      0.92      1000
        体育       0.98      0.98      0.98      1000
        游戏       0.98      0.93      0.95      1000
        娱乐       0.91      0.98      0.94      1000

    accuracy                           0.93     10000
   macro avg       0.94      0.93      0.93     10000
weighted avg       0.94      0.93      0.93     10000

round_5: (训练5轮)
                precision    recall  f1-score   support

        金融       0.92      0.93      0.92      1000
        房产       0.92      0.97      0.95      1000
        股票       0.89      0.89      0.89      1000
        教育       0.95      0.97      0.96      1000
        科技       0.85      0.92      0.89      1000
        社会       0.95      0.91      0.93      1000
        政治       0.92      0.90      0.91      1000
        体育       0.98      0.98      0.98      1000
        游戏       0.99      0.89      0.94      1000
        娱乐       0.95      0.96      0.96      1000

    accuracy                           0.93     10000
   macro avg       0.93      0.93      0.93     10000
weighted avg       0.93      0.93      0.93     10000

round_15: (训练15轮)
                  precision    recall  f1-score   support

          金融       0.93      0.93      0.93      1000
          房产       0.95      0.95      0.95      1000
          股票       0.89      0.90      0.89      1000
          教育       0.95      0.98      0.96      1000
          科技       0.88      0.91      0.89      1000
          社会       0.93      0.93      0.93      1000
          政治       0.95      0.89      0.92      1000
          体育       0.96      0.99      0.97      1000
          游戏       0.99      0.91      0.95      1000
          娱乐       0.94      0.96      0.95      1000

    accuracy                           0.94     10000
   macro avg       0.94      0.94      0.94     10000
weighted avg       0.94      0.94      0.94     10000

round_15: (训练15轮)
                  precision    recall  f1-score   support

          金融       0.92      0.93      0.93      1000
          房产       0.93      0.97      0.95      1000
          股票       0.90      0.89      0.89      1000
          教育       0.97      0.96      0.97      1000
          科技       0.89      0.91      0.90      1000
          社会       0.90      0.95      0.93      1000
          政治       0.94      0.89      0.92      1000
          体育       0.99      0.98      0.98      1000
          游戏       0.99      0.93      0.96      1000
          娱乐       0.95      0.96      0.96      1000

    accuracy                           0.94     10000
   macro avg       0.94      0.94      0.94     10000
weighted avg       0.94      0.94      0.94     10000
"""


